{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'ind': 24, 'activity_label': 'Roof shingle removal', 'ctx_a': 'A man is sitting on a roof.', 'ctx_b': 'he', 'ctx': 'A man is sitting on a roof. he', 'split': 'val', 'split_type': 'indomain', 'label': 3, 'endings': ['is using wrap to wrap a pair of skis.', 'is ripping level tiles off.', \"is holding a rubik's cube.\", 'starts pulling up roofing on a roof.'], 'source_id': 'activitynet~v_-JhWjGDPHMY'}\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "sys.path.insert(0, \"../\")\n",
    "from hellaswag import download, render_example, iterate_examples\n",
    "\n",
    "hs = iterate_examples(\"val\") \n",
    "example = next(hs)\n",
    "print(example)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "({'label': 3,\n",
       "  'ctx_tokens': [32, 582, 318, 5586, 319, 257, 9753, 13, 339],\n",
       "  'ending_tokens': [[318,\n",
       "    1262,\n",
       "    14441,\n",
       "    284,\n",
       "    14441,\n",
       "    257,\n",
       "    5166,\n",
       "    286,\n",
       "    1341,\n",
       "    271,\n",
       "    13],\n",
       "   [318, 34759, 1241, 19867, 572, 13],\n",
       "   [318, 4769, 257, 6437, 1134, 338, 23441, 13],\n",
       "   [4940, 10427, 510, 9753, 278, 319, 257, 9753, 13]]},\n",
       " tensor([[   32,   582,   318,  5586,   319,   257,  9753,    13,   339,   318,\n",
       "           1262, 14441,   284, 14441,   257,  5166,   286,  1341,   271,    13],\n",
       "         [   32,   582,   318,  5586,   319,   257,  9753,    13,   339,   318,\n",
       "          34759,  1241, 19867,   572,    13,     0,     0,     0,     0,     0],\n",
       "         [   32,   582,   318,  5586,   319,   257,  9753,    13,   339,   318,\n",
       "           4769,   257,  6437,  1134,   338, 23441,    13,     0,     0,     0],\n",
       "         [   32,   582,   318,  5586,   319,   257,  9753,    13,   339,  4940,\n",
       "          10427,   510,  9753,   278,   319,   257,  9753,    13,     0,     0]]),\n",
       " tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
       "         [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],\n",
       "         [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],\n",
       "         [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]),\n",
       " 3)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "rendered = render_example(example)\n",
    "rendered"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAhYAAACOCAYAAABt7UHUAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy80BEi2AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAYkElEQVR4nO3de1CTZ74H8G8gFwKNVFGEiMQ7rBfwznpZ3VYG8DgK6lZqrYvFdk47YSuyy1q3q2i7LdKuVmsdrC7ai0tr2yNqdQpFVrDOVlGQFXddvNRjVRSqK3cKIXnOHx04IlaT8ISG+P3MZJDXl9/3edAn+eXNm7wKIYQAERERkQRuP/UAiIiIyHWwsSAiIiJp2FgQERGRNGwsiIiISBo2FkRERCQNGwsiIiKSho0FERERSaPsyjCLxYLy8nLodDooFIqujCYiIiI7CSFQW1sLvV4PN7f7H5Po0saivLwc/fv378pIIiIikuTKlSsICAi47z5d2ljodDoAwFT8F5RQWfUzSq0K8RnzsGPpHrQ0mhwyLkdnuMIcmOE89e3NyDpXalOGqUWNQyWrET76FaiUzfYM86HIcIU5MMN56jtrRk2dBYax/9v2OH4/XdpYtL78oYQKSoV1jYVKoYKnpydUChXgoFdPHJ3hCnNghvPUtzejh862U6pMLe7w9PRED507VErHnI7lChmuMAdmOE99Z8+w5jQGnrxJRERE0rCxICIiImnYWBAREZE0bCyIiIhIGjYWREREJA0bCyIiIpLGrsZiy5YtGDBgADw8PBAWFobCwkLZ4yIiIqJuyObGYvfu3UhKSkJKSgqKi4sRGhqKyMhIVFZWOmJ8RERE1I3Y/AFZGzZswHPPPYdnnnkGALB161YcPHgQO3bswEsvvdRu36amJjQ1NbV9X1NT80OoVvXDB/tYQaVVtvvqCI7OcIU5MMN56tubYWrR2JTRYta0++oIrpDhCnNghvPUd9YMU4vZ6toKIYSwdufm5mZ4enris88+Q0xMTNv2uLg4VFVVYd++fe32X7NmDdauXduhTmZmJjw9Pa0eJBEREf10Ghoa8NRTT6G6uho9evS47742PaW6efMmzGYz+vbt225737598e9//7vD/itXrkRSUlLb9zU1Nejfvz92LN1j0xGL+Iz52LH0f2BqbLFluFZzdIYrzIEZzlPf3oysstM2ZbSYNW3XElC6Nz34B+zgChmuMAdmOE99Z82oqbX+iIVDrxWi0Wig0XQ8zNLSaLL5GgqmxhaYHHRBp67KcIU5MMN56tuaoVLadweldG+y+2cfpgxXmAMznKe+s2WolBara9p08mbv3r3h7u6OioqKdtsrKirg5+dnSykiIiJyQTY1Fmq1GuPGjUNeXl7bNovFgry8PEyaNEn64IiIiKh7sfmlkKSkJMTFxWH8+PGYOHEiNm7ciPr6+rZ3iRAREdHDy+bGIjY2Ft999x1Wr16NGzduYPTo0cjOzu5wQicRERE9fOw6eTMhIQEJCQmyx0JERETdHK8VQkRERNKwsSAiIiJp2FgQERGRNGwsiIiISBqHfvImERG5rkj9aJv2V2lV+O9MYG5QiMM+ldbRGa4wB3syWoQJwDdW1eYRCyIiIpKGjQURERFJw8aCiIiIpGFjQURERNKwsSAiIiJp2FgQERGRNGwsiIiISBo2FkRERCQNGwsiIiKSho0FERERScPGgoiIiKRhY0FERETSsLEgIiIiadhYEBERkTRsLIiIiEgaNhZEREQkDRsLIiIikoaNBREREUnDxoKIiIikYWNBRERE0rCxICIiImnYWBAREZE0bCyIiIhIGjYWREREJI3ypx4AERF1TznlJTbtb2rRILsoFlllp6FSNjlkTPZkROpHO2QsDysesSAiIiJp2FgQERGRNGwsiIiISBo2FkRERCQNGwsiIiKSho0FERERScPGgoiIiKRhY0FERETSsLEgIiIiaWxqLFJTUzFhwgTodDr4+voiJiYGZWVljhobERERdTM2NRYFBQUwGo04duwYcnNzYTKZEBERgfr6ekeNj4iIiLoRm64Vkp2d3e779957D76+vigqKsK0adM67N/U1ISmpv//rPaampofQrUqqBQqqzJVWmW7r47g6AxXmAMznKe+vRmmFo1NGS1mTbuvjuAKGa4wh4c9Q6W17vHoh32dc307PEMAaLRuV4UQQtg3LODChQsYOnQoSktLMXLkyA5/v2bNGqxdu7bD9szMTHh6etobS0RERF2ooaEBTz31FKqrq9GjR4/77mt3Y2GxWDBnzhxUVVXh6NGj99znXkcs+vfvj3Dtr2w6YhGfMR87lv4PTI0t9gz1J89whTkww3nq25uRVXbapowWswaHSlYjfPQrULo75kqUrpDhCnN42DPmBoVYXd9Z17ejM0zChEONn1nVWNh9nMVoNOLMmTM/2lQAgEajgUbT8XBUS6MJUNiWZ2psganRZOswnSrDFebADOepb2uGvZepVro3OewS166U4QpzeFgz7Fmnzra+HZ3RIqwfh12NRUJCAg4cOIAjR44gICDAnhJERETkgmxqLIQQ+M1vfoOsrCzk5+dj4MCBjhoXERERdUM2NRZGoxGZmZnYt28fdDodbty4AQDw9vaGVqt1yACJiIio+7DpcyzS09NRXV2NX/7yl/D392+77d6921HjIyIiom7E5pdCiIiIiH4MrxVCRERE0rCxICIiImnYWBAREZE0bCyIiIhIGsdd4YSIiKgbyCkvsXpfU4sG2UWxyCo7bfUne0bqR9s3sG6KRyyIiIhIGjYWREREJA0bCyIiIpKGjQURERFJw8aCiIiIpGFjQURERNKwsSAiIiJp2FgQERGRNGwsiIiISBo2FkRERCQNGwsiIiKSho0FERERScPGgoiIiKRhY0FERETSsLEgIiIiadhYEBERkTRsLIiIiEgaNhZEREQkDRsLIiIikoaNBREREUnDxoKIiIikYWNBRERE0rCxICIiImmUXRkmhAAAtMAECGt/CGhoaIBJmNAiTA4amIMzXGEOzHCe+nZm1NRabIowtZjR0NCAmlozVErbfvZhynCFOTDDsfVtvh9wwvuQFvywT+vj+P0ohDV7SXL16lX079+/q+KIiIhIoitXriAgIOC++3RpY2GxWFBeXg6dTgeFQmHVz9TU1KB///64cuUKevTo4ZBxOTrDFebADOepzwznynCFOTDDeeo7a4YQArW1tdDr9XBzu/9ZFF36Uoibm9sDO50f06NHD4f9grsqwxXmwAznqc8M58pwhTkww3nqO2OGt7e3Vfvx5E0iIiKSho0FERERSeP0jYVGo0FKSgo0Gk23zXCFOTDDeeozw7kyXGEOzHCe+q6Q0aUnbxIREZFrc/ojFkRERNR9sLEgIiIiadhYEBERkTRsLIiIiEgaNhZEREQkjVM3Flu2bMGAAQPg4eGBsLAwFBYWSq1/5MgRzJ49G3q9HgqFAnv37pVaPzU1FRMmTIBOp4Ovry9iYmJQVlYmNSM9PR0hISFtn542adIkfPHFF1Iz7rRu3TooFAokJiZKrbtmzRooFIp2t+DgYKkZ165dw9NPPw0fHx9otVqMGjUKJ0+elFZ/wIABHeagUChgNBqlZZjNZqxatQoDBw6EVqvF4MGD8eqrr1p1YSBb1NbWIjExEQaDAVqtFpMnT8aJEyfsrvegtSaEwOrVq+Hv7w+tVovw8HCcP39eWv09e/YgIiICPj4+UCgUKCkpkToHk8mEFStWYNSoUfDy8oJer8evf/1rlJeXS8sAflgnwcHB8PLyQs+ePREeHo7jx49LzbjT888/D4VCgY0bN0rNWLJkSYd1EhUVJXUOZ8+exZw5c+Dt7Q0vLy9MmDAB3377rbSMe611hUKBN998U1pGXV0dEhISEBAQAK1Wi+HDh2Pr1q1W17cmo6KiAkuWLIFer4enpyeioqJsWnv34rSNxe7du5GUlISUlBQUFxcjNDQUkZGRqKyslJZRX1+P0NBQbNmyRVrNOxUUFMBoNOLYsWPIzc2FyWRCREQE6uvrpWUEBARg3bp1KCoqwsmTJ/H4448jOjoa//znP6VltDpx4gTeffddhISESK8NACNGjMD169fbbkePHpVW+/bt25gyZQpUKhW++OIL/Otf/8L69evRs2dPaRknTpxoN/7c3FwAwBNPPCEtIy0tDenp6XjnnXdw9uxZpKWl4Y033sDmzZulZQDAs88+i9zcXHz44YcoLS1FREQEwsPDce3aNbvqPWitvfHGG3j77bexdetWHD9+HF5eXoiMjMT3338vpX59fT2mTp2KtLQ0u8b/oIyGhgYUFxdj1apVKC4uxp49e1BWVoY5c+ZIywCAYcOG4Z133kFpaSmOHj2KAQMGICIiAt999520jFZZWVk4duwY9Hq9TXOwNiMqKqrdevnoo4+k1b948SKmTp2K4OBg5Ofn4/Tp01i1ahU8PDykZdw59uvXr2PHjh1QKBSYP3++tIykpCRkZ2dj165dOHv2LBITE5GQkID9+/dLyRBCICYmBt988w327duHU6dOwWAwIDw8vHOPU8JJTZw4URiNxrbvzWaz0Ov1IjU11SF5AERWVpZDareqrKwUAERBQYFDc3r27Cn+8pe/SK1ZW1srhg4dKnJzc8X06dPFsmXLpNZPSUkRoaGhUmveacWKFWLq1KkOq38vy5YtE4MHDxYWi0VazVmzZon4+Ph22+bNmycWLVokLaOhoUG4u7uLAwcOtNs+duxY8fLLL3e6/t1rzWKxCD8/P/Hmm2+2bauqqhIajUZ89NFHna5/p0uXLgkA4tSpUzbXtTajVWFhoQAgLl++7LCM6upqAUAcOnRIasbVq1dFv379xJkzZ4TBYBBvvfWWXfV/LCMuLk5ER0fbXfNB9WNjY8XTTz8tpf6PZdwtOjpaPP7441IzRowYIV555ZV22zqzDu/OKCsrEwDEmTNn2raZzWbRp08fsX37drsyhBDCKY9YNDc3o6ioCOHh4W3b3NzcEB4ejq+//vonHFnnVFdXAwB69erlkPpmsxkff/wx6uvrMWnSJKm1jUYjZs2a1e7fRLbz589Dr9dj0KBBWLRokU2HLR9k//79GD9+PJ544gn4+vpizJgx2L59u7T6d2tubsauXbsQHx9v9ZV8rTF58mTk5eXh3LlzAIB//OMfOHr0KGbOnCkto6WlBWazucOzO61WK/UoUqtLly7hxo0b7f5veXt7IywsrNuvd4VCgUcffdQh9Zubm7Ft2zZ4e3sjNDRUWl2LxYLFixcjOTkZI0aMkFb3bvn5+fD19UVQUBBeeOEF3Lp1S0pdi8WCgwcPYtiwYYiMjISvry/CwsKkv9R9p4qKChw8eBBLly6VWnfy5MnYv38/rl27BiEEDh8+jHPnziEiIkJK/aamJgBot9bd3Nyg0Wg6tdadsrG4efMmzGYz+vbt22573759cePGjZ9oVJ1jsViQmJiIKVOmYOTIkVJrl5aW4pFHHoFGo8Hzzz+PrKwsDB8+XFr9jz/+GMXFxUhNTZVW825hYWF47733kJ2djfT0dFy6dAm/+MUvUFtbK6X+N998g/T0dAwdOhQ5OTl44YUX8OKLL+L999+XUv9ue/fuRVVVFZYsWSK17ksvvYQnn3wSwcHBUKlUGDNmDBITE7Fo0SJpGTqdDpMmTcKrr76K8vJymM1m7Nq1C19//TWuX78uLadV65p2pfX+/fffY8WKFVi4cKH0q1MeOHAAjzzyCDw8PPDWW28hNzcXvXv3llY/LS0NSqUSL774orSad4uKisIHH3yAvLw8pKWloaCgADNnzoTZbO507crKStTV1WHdunWIiorCl19+iblz52LevHkoKCiQMPqO3n//feh0OsybN09q3c2bN2P48OEICAiAWq1GVFQUtmzZgmnTpkmpHxwcjMDAQKxcuRK3b99Gc3Mz0tLScPXq1U6t9S69bPrDzGg04syZMw55xhcUFISSkhJUV1fjs88+Q1xcHAoKCqQ0F1euXMGyZcuQm5tr0+uTtrrzGXdISAjCwsJgMBjwySefSHkWYLFYMH78eLz++usAgDFjxuDMmTPYunUr4uLiOl3/bhkZGZg5c6Zdr0/fzyeffIK//vWvyMzMxIgRI1BSUoLExETo9Xqp8/jwww8RHx+Pfv36wd3dHWPHjsXChQtRVFQkLcNVmUwmLFiwAEIIpKenS6//2GOPoaSkBDdv3sT27duxYMECHD9+HL6+vp2uXVRUhE2bNqG4uFjqkba7Pfnkk21/HjVqFEJCQjB48GDk5+djxowZnaptsVgAANHR0Vi+fDkAYPTo0fj73/+OrVu3Yvr06Z2qfy87duzAokWLpN9Hbt68GceOHcP+/fthMBhw5MgRGI1G6PV6KUePVSoV9uzZg6VLl6JXr15wd3dHeHg4Zs6c2akTwp3yiEXv3r3h7u6OioqKdtsrKirg5+f3E43KfgkJCThw4AAOHz6MgIAA6fXVajWGDBmCcePGITU1FaGhodi0aZOU2kVFRaisrMTYsWOhVCqhVCpRUFCAt99+G0qlUsozjHt59NFHMWzYMFy4cEFKPX9//w6N1s9+9jOpL7e0unz5Mg4dOoRnn31Weu3k5OS2oxajRo3C4sWLsXz5culHkwYPHoyCggLU1dXhypUrKCwshMlkwqBBg6TmAGhb066w3lubisuXLyM3N1f60QoA8PLywpAhQ/Dzn/8cGRkZUCqVyMjIkFL7q6++QmVlJQIDA9vW++XLl/Hb3/4WAwYMkJJxL4MGDULv3r2lrPfevXtDqVR22Xr/6quvUFZWJn29NzY24g9/+AM2bNiA2bNnIyQkBAkJCYiNjcWf//xnaTnjxo1DSUkJqqqqcP36dWRnZ+PWrVudWutO2Vio1WqMGzcOeXl5bdssFgvy8vKknzvgSEIIJCQkICsrC3/7298wcODALsm1WCxtr5111owZM1BaWoqSkpK22/jx47Fo0SKUlJTA3d1dSs7d6urqcPHiRfj7+0upN2XKlA5v9T137hwMBoOU+nfauXMnfH19MWvWLOm1Gxoa4ObWftm6u7u3PUuTzcvLC/7+/rh9+zZycnIQHR0tPWPgwIHw8/Nrt95rampw/PjxbrXeW5uK8+fP49ChQ/Dx8emSXJnrffHixTh9+nS79a7X65GcnIycnBwpGfdy9epV3Lp1S8p6V6vVmDBhQpet94yMDIwbN07qeS7AD/+fTCZTl613b29v9OnTB+fPn8fJkyc7tdad9qWQpKQkxMXFYfz48Zg4cSI2btyI+vp6PPPMM9Iy6urq2nXIly5dQklJCXr16oXAwMBO1zcajcjMzMS+ffug0+naXi/29vaGVqvtdH0AWLlyJWbOnInAwEDU1tYiMzMT+fn50u4EdDpdh3NCvLy84OPjI/Vckd/97neYPXs2DAYDysvLkZKSAnd3dyxcuFBK/eXLl2Py5Ml4/fXXsWDBAhQWFmLbtm3Ytm2blPqtLBYLdu7cibi4OCiV8pfX7Nmz8dprryEwMBAjRozAqVOnsGHDBsTHx0vNycnJgRACQUFBuHDhApKTkxEcHGz3+nvQWktMTMSf/vQnDB06FAMHDsSqVaug1+sRExMjpf5//vMffPvtt22fK9H6oOPn52f1UZH7Zfj7++NXv/oViouLceDAAZjN5rb13qtXL6jV6k5n+Pj44LXXXsOcOXPg7++PmzdvYsuWLbh27ZpNb2l+0O/q7oZIpVLBz88PQUFBUjJ69eqFtWvXYv78+fDz88PFixfx+9//HkOGDEFkZKSUOSQnJyM2NhbTpk3DY489huzsbHz++efIz8+XMofWx4eamhp8+umnWL9+vdV1bcmYPn06kpOTodVqYTAYUFBQgA8++AAbNmyQlvHpp5+iT58+CAwMRGlpKZYtW4aYmJjOnSBq9/tJusDmzZtFYGCgUKvVYuLEieLYsWNS6x8+fFgA6HCLi4uTUv9etQGInTt3SqkvhBDx8fHCYDAItVot+vTpI2bMmCG+/PJLafXvxRFvN42NjRX+/v5CrVaLfv36idjYWHHhwgWpGZ9//rkYOXKk0Gg0Ijg4WGzbtk1qfSGEyMnJEQBEWVmZ9NpCCFFTUyOWLVsmAgMDhYeHhxg0aJB4+eWXRVNTk9Sc3bt3i0GDBgm1Wi38/PyE0WgUVVVVdtd70FqzWCxi1apVom/fvkKj0YgZM2bY9Dt8UP2dO3fe8+9TUlKkZLS+jfVet8OHD0vJaGxsFHPnzhV6vV6o1Wrh7+8v5syZIwoLC62ub83v6m72vN30fhkNDQ0iIiJC9OnTR6hUKmEwGMRzzz0nbty4IXUOGRkZYsiQIcLDw0OEhoaKvXv3SptDq3fffVdotVq718aDMq5fvy6WLFki9Hq98PDwEEFBQWL9+vU2vYX9QRmbNm0SAQEBQqVSicDAQPHHP/6x0/cnCiEkf2QfERERPbSc8hwLIiIi6p7YWBAREZE0bCyIiIhIGjYWREREJA0bCyIiIpKGjQURERFJw8aCiIiIpGFjQURERNKwsSAiIiJp2FgQERGRNGwsiIiISJr/Ax+KWVo7QH6tAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.imshow(rendered[-2])\n",
    "plt.xticks(np.arange(0, rendered[-2].size(1), 1))\n",
    "plt.grid()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading gpt2 model weights from transformers\n"
     ]
    }
   ],
   "source": [
    "from gpt2 import GPT\n",
    "from torch.nn import functional as F\n",
    "\n",
    "model = GPT.from_pretrained(\"gpt2\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\n",
      "\n",
      "0. tokens:\n",
      " torch.Size([4, 20])\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "1. Logits:\n",
      " torch.Size([4, 20, 50257])\n",
      "sentences * tokens * vocab\n",
      "tensor([[[ -33.5706,  -32.7689,  -35.4510,  ...,  -40.9807,  -40.1867,\n",
      "           -33.2153],\n",
      "         [-105.0330, -103.2583, -110.3991,  ..., -111.8692, -104.8637,\n",
      "          -105.6672],\n",
      "         [-119.8807, -117.7583, -123.3105,  ..., -127.7225, -120.2321,\n",
      "          -120.7401],\n",
      "         ...,\n",
      "         [ -98.9866,  -95.4444,  -97.4640,  ..., -108.0764, -107.3866,\n",
      "          -100.0664],\n",
      "         [-128.3072, -128.1642, -132.5630,  ..., -141.1546, -139.0250,\n",
      "          -128.6488],\n",
      "         [-108.6667, -107.1693, -107.7597,  ..., -113.0093, -113.0028,\n",
      "          -103.4984]],\n",
      "\n",
      "        [[ -33.5706,  -32.7689,  -35.4510,  ...,  -40.9807,  -40.1867,\n",
      "           -33.2153],\n",
      "         [-105.0330, -103.2583, -110.3991,  ..., -111.8692, -104.8637,\n",
      "          -105.6672],\n",
      "         [-119.8807, -117.7583, -123.3105,  ..., -127.7225, -120.2321,\n",
      "          -120.7401],\n",
      "         ...,\n",
      "         [ -81.1260,  -85.5751,  -84.7267,  ...,  -95.5109,  -93.9703,\n",
      "           -84.2436],\n",
      "         [ -75.6481,  -80.6936,  -79.9605,  ...,  -90.5047,  -89.4759,\n",
      "           -79.6117],\n",
      "         [ -70.4409,  -76.0205,  -75.1695,  ...,  -85.7483,  -85.0923,\n",
      "           -75.0535]],\n",
      "\n",
      "        [[ -33.5706,  -32.7689,  -35.4510,  ...,  -40.9807,  -40.1867,\n",
      "           -33.2153],\n",
      "         [-105.0330, -103.2583, -110.3991,  ..., -111.8692, -104.8637,\n",
      "          -105.6672],\n",
      "         [-119.8807, -117.7583, -123.3105,  ..., -127.7225, -120.2321,\n",
      "          -120.7401],\n",
      "         ...,\n",
      "         [-103.5758, -104.5897, -103.9381,  ..., -113.1633, -109.9642,\n",
      "          -102.4687],\n",
      "         [ -92.4719,  -95.6433,  -94.8487,  ..., -105.7557, -102.0913,\n",
      "           -94.0967],\n",
      "         [ -86.1454,  -90.1905,  -89.6241,  ..., -100.0780,  -97.1400,\n",
      "           -88.8879]],\n",
      "\n",
      "        [[ -33.5706,  -32.7689,  -35.4510,  ...,  -40.9807,  -40.1867,\n",
      "           -33.2153],\n",
      "         [-105.0330, -103.2583, -110.3991,  ..., -111.8692, -104.8637,\n",
      "          -105.6672],\n",
      "         [-119.8807, -117.7583, -123.3105,  ..., -127.7225, -120.2321,\n",
      "          -120.7401],\n",
      "         ...,\n",
      "         [-107.3501, -106.5545, -107.0908,  ..., -113.0025, -111.9331,\n",
      "          -102.7872],\n",
      "         [-101.7971, -103.1369, -102.5732,  ..., -112.5466, -109.8921,\n",
      "          -101.0919],\n",
      "         [ -89.9524,  -92.9929,  -92.5344,  ..., -104.0673, -100.7157,\n",
      "           -91.5778]]], grad_fn=<UnsafeViewBackward0>)\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "2. Shift Logits:\n",
      " torch.Size([4, 19, 50257])\n",
      "tensor([[[ -33.5706,  -32.7689,  -35.4510,  ...,  -40.9807,  -40.1867,\n",
      "           -33.2153],\n",
      "         [-105.0330, -103.2583, -110.3991,  ..., -111.8692, -104.8637,\n",
      "          -105.6672],\n",
      "         [-119.8807, -117.7583, -123.3105,  ..., -127.7225, -120.2321,\n",
      "          -120.7401],\n",
      "         ...,\n",
      "         [ -98.9866,  -95.4444,  -97.4640,  ..., -108.0764, -107.3866,\n",
      "          -100.0664],\n",
      "         [-128.3072, -128.1642, -132.5630,  ..., -141.1546, -139.0250,\n",
      "          -128.6488],\n",
      "         [-108.6667, -107.1693, -107.7597,  ..., -113.0093, -113.0028,\n",
      "          -103.4984]],\n",
      "\n",
      "        [[ -33.5706,  -32.7689,  -35.4510,  ...,  -40.9807,  -40.1867,\n",
      "           -33.2153],\n",
      "         [-105.0330, -103.2583, -110.3991,  ..., -111.8692, -104.8637,\n",
      "          -105.6672],\n",
      "         [-119.8807, -117.7583, -123.3105,  ..., -127.7225, -120.2321,\n",
      "          -120.7401],\n",
      "         ...,\n",
      "         [ -81.1260,  -85.5751,  -84.7267,  ...,  -95.5109,  -93.9703,\n",
      "           -84.2436],\n",
      "         [ -75.6481,  -80.6936,  -79.9605,  ...,  -90.5047,  -89.4759,\n",
      "           -79.6117],\n",
      "         [ -70.4409,  -76.0205,  -75.1695,  ...,  -85.7483,  -85.0923,\n",
      "           -75.0535]],\n",
      "\n",
      "        [[ -33.5706,  -32.7689,  -35.4510,  ...,  -40.9807,  -40.1867,\n",
      "           -33.2153],\n",
      "         [-105.0330, -103.2583, -110.3991,  ..., -111.8692, -104.8637,\n",
      "          -105.6672],\n",
      "         [-119.8807, -117.7583, -123.3105,  ..., -127.7225, -120.2321,\n",
      "          -120.7401],\n",
      "         ...,\n",
      "         [-103.5758, -104.5897, -103.9381,  ..., -113.1633, -109.9642,\n",
      "          -102.4687],\n",
      "         [ -92.4719,  -95.6433,  -94.8487,  ..., -105.7557, -102.0913,\n",
      "           -94.0967],\n",
      "         [ -86.1454,  -90.1905,  -89.6241,  ..., -100.0780,  -97.1400,\n",
      "           -88.8879]],\n",
      "\n",
      "        [[ -33.5706,  -32.7689,  -35.4510,  ...,  -40.9807,  -40.1867,\n",
      "           -33.2153],\n",
      "         [-105.0330, -103.2583, -110.3991,  ..., -111.8692, -104.8637,\n",
      "          -105.6672],\n",
      "         [-119.8807, -117.7583, -123.3105,  ..., -127.7225, -120.2321,\n",
      "          -120.7401],\n",
      "         ...,\n",
      "         [-107.3501, -106.5545, -107.0908,  ..., -113.0025, -111.9331,\n",
      "          -102.7872],\n",
      "         [-101.7971, -103.1369, -102.5732,  ..., -112.5466, -109.8921,\n",
      "          -101.0919],\n",
      "         [ -89.9524,  -92.9929,  -92.5344,  ..., -104.0673, -100.7157,\n",
      "           -91.5778]]], grad_fn=<UnsafeViewBackward0>)\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "3. Shift Tokens:\n",
      " torch.Size([4, 19])\n",
      "tensor([[  582,   318,  5586,   319,   257,  9753,    13,   339,   318,  1262,\n",
      "         14441,   284, 14441,   257,  5166,   286,  1341,   271,    13],\n",
      "        [  582,   318,  5586,   319,   257,  9753,    13,   339,   318, 34759,\n",
      "          1241, 19867,   572,    13,     0,     0,     0,     0,     0],\n",
      "        [  582,   318,  5586,   319,   257,  9753,    13,   339,   318,  4769,\n",
      "           257,  6437,  1134,   338, 23441,    13,     0,     0,     0],\n",
      "        [  582,   318,  5586,   319,   257,  9753,    13,   339,  4940, 10427,\n",
      "           510,  9753,   278,   319,   257,  9753,    13,     0,     0]])\n",
      "is contiguous? True\n",
      "1 acc_norm: 0/1=0.0000\n",
      "---\n",
      "Context:\n",
      " A man is sitting on a roof. he\n",
      "Endings:\n",
      "0 (loss: 3.9474) is using wrap to wrap a pair of skis.\n",
      "1 (loss: 5.7732) is ripping level tiles off.\n",
      "2 (loss: 2.9486) is holding a rubik's cube.\n",
      "3 (loss: 4.0702) starts pulling up roofing on a roof.\n",
      "predicted: 2, actual: 3\n"
     ]
    }
   ],
   "source": [
    "num_correct_norm = 0\n",
    "num_correct = 0\n",
    "num_total = 0\n",
    "\n",
    "data, tokens, mask, label = rendered\n",
    "\n",
    "print(\"\\n\\n\\n\\n0. tokens:\\n\", tokens.shape)\n",
    "\n",
    "# get the logits\n",
    "logits, loss = model(tokens)\n",
    "print(\"\\n\\n\\n\\n1. Logits:\\n\", logits.shape)\n",
    "print(\"sentences * tokens * vocab\")\n",
    "print(logits)\n",
    "\n",
    "# evaluate the autoregressive loss at all positions\n",
    "# model does not output the logits for the first token, it is treated as the holy ground truth\n",
    "# so we have to cut the first token -- its got not corresponding logits (cause we are not getting logits for the \"empty start most probable token\")\n",
    "# and we have to cut the last token -- we dont care about further tokens, only want to evaluate loss for the ones given\n",
    "# contiguous() is used to ensure that the memory is laid out in a contiguous block, it returns a copy if tensor is not contiguous or self otherwise\n",
    "#   ways to check if copy: original.data_ptr() == contiguous.data_ptr() or id(original) == id(contiguous)\n",
    "shift_logits = (logits[:, :-1, :]).contiguous()\n",
    "print(\"\\n\\n\\n\\n2. Shift Logits:\\n\", shift_logits.shape)\n",
    "print(logits)\n",
    "\n",
    "shift_tokens = (tokens[:, 1:]).contiguous()\n",
    "print(\"\\n\\n\\n\\n3. Shift Tokens:\\n\", shift_tokens.shape)\n",
    "print(shift_tokens)\n",
    "\n",
    "# now we have to compute the loss with respect to shift_logits\n",
    "flat_shift_logits = shift_logits.view(-1, shift_logits.size(-1))\n",
    "flat_shift_tokens = shift_tokens.view(-1)\n",
    "shift_losses = F.cross_entropy(flat_shift_logits, flat_shift_tokens, reduction='none')\n",
    "shift_losses = shift_losses.view(tokens.size(0), -1)\n",
    "print(\"is contiguous?\", shift_losses.is_contiguous())\n",
    "# now get the average loss just for the completion region (where mask == 1), in each row\n",
    "shift_mask = (mask[..., 1:]).contiguous() # we must shift mask, so we start at the last prompt token\n",
    "masked_shift_losses = shift_losses * shift_mask\n",
    "# sum and divide by the number of 1s in the mask\n",
    "sum_loss = masked_shift_losses.sum(dim=1)\n",
    "avg_loss = sum_loss / shift_mask.sum(dim=1)\n",
    "# now we have a loss for each of the 4 completions\n",
    "# the one with the lowest loss should be the most likely\n",
    "pred = sum_loss.argmin().item()\n",
    "pred_norm = avg_loss.argmin().item()\n",
    "\n",
    "# accumulate stats\n",
    "num_total += 1\n",
    "num_correct += int(pred == label)\n",
    "num_correct_norm += int(pred_norm == label)\n",
    "print(f\"{num_total} acc_norm: {num_correct_norm}/{num_total}={num_correct_norm/num_total:.4f}\")\n",
    "\n",
    "# debug: pretty print a few examples, and the losses in each case\n",
    "if num_total < 10:\n",
    "    print(\"---\")\n",
    "    print(f\"Context:\\n {example['ctx']}\")\n",
    "    print(f\"Endings:\")\n",
    "    for i, end in enumerate(example[\"endings\"]):\n",
    "        print(f\"{i} (loss: {avg_loss[i].item():.4f}) {end}\")\n",
    "    print(f\"predicted: {pred_norm}, actual: {label}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([4, 19, 50257])\n",
      "torch.Size([76, 50257])\n"
     ]
    }
   ],
   "source": [
    "og = shift_logits\n",
    "print(og.shape)\n",
    "ex = shift_logits.view(-1, shift_logits.size(-1))\n",
    "print(ex.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'A man is sitting on a roof. he is using wrap to wrap a pair of skis.'"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import tiktoken\n",
    "\n",
    "enc = tiktoken.get_encoding(\"gpt2\")\n",
    "\n",
    "enc.decode(tokens[0].tolist())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[32]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(tensor([[[ -33.5706,  -32.7689,  -35.4510,  ...,  -40.9807,  -40.1867,\n",
       "            -33.2153],\n",
       "          [-105.0330, -103.2583, -110.3991,  ..., -111.8692, -104.8637,\n",
       "           -105.6672],\n",
       "          [-119.8807, -117.7583, -123.3105,  ..., -127.7225, -120.2321,\n",
       "           -120.7401],\n",
       "          ...,\n",
       "          [ -98.9866,  -95.4444,  -97.4640,  ..., -108.0764, -107.3866,\n",
       "           -100.0664],\n",
       "          [-128.3072, -128.1642, -132.5630,  ..., -141.1546, -139.0250,\n",
       "           -128.6488],\n",
       "          [-108.6667, -107.1693, -107.7597,  ..., -113.0093, -113.0028,\n",
       "           -103.4984]]], grad_fn=<UnsafeViewBackward0>),\n",
       " None)"
      ]
     },
     "execution_count": 57,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "print(enc.encode(\"A\"))\n",
    "model(tokens[0:1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 19])"
      ]
     },
     "execution_count": 65,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "shift_losses.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(7.3022, grad_fn=<MulBackward0>)\n",
      "tensor(7.3022, grad_fn=<NllLossBackward0>)\n"
     ]
    }
   ],
   "source": [
    "import torch \n",
    "\n",
    "def cross_entropy(t, target):\n",
    "    return -1 * torch.log(torch.exp(t[target]) / torch.exp(t).sum())\n",
    "\n",
    "print(cross_entropy(shift_logits[0, 0], shift_tokens[0, 0]))\n",
    "\n",
    "print(F.cross_entropy(shift_logits[0, 0].unsqueeze(0), shift_tokens[0, 0].unsqueeze(0)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[ -33.5706,  -32.7689,  -35.4510,  ...,  -40.9807,  -40.1867,\n",
       "           -33.2153],\n",
       "         [-105.0330, -103.2583, -110.3991,  ..., -111.8692, -104.8637,\n",
       "          -105.6672],\n",
       "         [-119.8807, -117.7583, -123.3105,  ..., -127.7225, -120.2321,\n",
       "          -120.7401],\n",
       "         ...,\n",
       "         [ -98.9866,  -95.4444,  -97.4640,  ..., -108.0764, -107.3866,\n",
       "          -100.0664],\n",
       "         [-128.3072, -128.1642, -132.5630,  ..., -141.1546, -139.0250,\n",
       "          -128.6488],\n",
       "         [-108.6667, -107.1693, -107.7597,  ..., -113.0093, -113.0028,\n",
       "          -103.4984]],\n",
       "\n",
       "        [[ -33.5706,  -32.7689,  -35.4510,  ...,  -40.9807,  -40.1867,\n",
       "           -33.2153],\n",
       "         [-105.0330, -103.2583, -110.3991,  ..., -111.8692, -104.8637,\n",
       "          -105.6672],\n",
       "         [-119.8807, -117.7583, -123.3105,  ..., -127.7225, -120.2321,\n",
       "          -120.7401],\n",
       "         ...,\n",
       "         [ -81.1260,  -85.5751,  -84.7267,  ...,  -95.5109,  -93.9703,\n",
       "           -84.2436],\n",
       "         [ -75.6481,  -80.6936,  -79.9605,  ...,  -90.5047,  -89.4759,\n",
       "           -79.6117],\n",
       "         [ -70.4409,  -76.0205,  -75.1695,  ...,  -85.7483,  -85.0923,\n",
       "           -75.0535]],\n",
       "\n",
       "        [[ -33.5706,  -32.7689,  -35.4510,  ...,  -40.9807,  -40.1867,\n",
       "           -33.2153],\n",
       "         [-105.0330, -103.2583, -110.3991,  ..., -111.8692, -104.8637,\n",
       "          -105.6672],\n",
       "         [-119.8807, -117.7583, -123.3105,  ..., -127.7225, -120.2321,\n",
       "          -120.7401],\n",
       "         ...,\n",
       "         [-103.5758, -104.5897, -103.9381,  ..., -113.1633, -109.9642,\n",
       "          -102.4687],\n",
       "         [ -92.4719,  -95.6433,  -94.8487,  ..., -105.7557, -102.0913,\n",
       "           -94.0967],\n",
       "         [ -86.1454,  -90.1905,  -89.6241,  ..., -100.0780,  -97.1400,\n",
       "           -88.8879]],\n",
       "\n",
       "        [[ -33.5706,  -32.7689,  -35.4510,  ...,  -40.9807,  -40.1867,\n",
       "           -33.2153],\n",
       "         [-105.0330, -103.2583, -110.3991,  ..., -111.8692, -104.8637,\n",
       "          -105.6672],\n",
       "         [-119.8807, -117.7583, -123.3105,  ..., -127.7225, -120.2321,\n",
       "          -120.7401],\n",
       "         ...,\n",
       "         [-107.3501, -106.5545, -107.0908,  ..., -113.0025, -111.9331,\n",
       "          -102.7872],\n",
       "         [-101.7971, -103.1369, -102.5732,  ..., -112.5466, -109.8921,\n",
       "          -101.0919],\n",
       "         [ -89.9524,  -92.9929,  -92.5344,  ..., -104.0673, -100.7157,\n",
       "           -91.5778]]], grad_fn=<UnsafeViewBackward0>)"
      ]
     },
     "execution_count": 75,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ -33.5706, -105.0330, -119.8807,  ...,  -98.9866, -128.3072,\n",
       "         -108.6667],\n",
       "        [ -32.7689, -103.2583, -117.7583,  ...,  -95.4444, -128.1642,\n",
       "         -107.1693],\n",
       "        [ -35.4510, -110.3991, -123.3105,  ...,  -97.4640, -132.5630,\n",
       "         -107.7597],\n",
       "        ...,\n",
       "        [ -40.9807, -111.8692, -127.7225,  ..., -108.0764, -141.1546,\n",
       "         -113.0093],\n",
       "        [ -40.1867, -104.8637, -120.2321,  ..., -107.3866, -139.0250,\n",
       "         -113.0028],\n",
       "        [ -33.2153, -105.6672, -120.7401,  ..., -100.0664, -128.6488,\n",
       "         -103.4984]], grad_fn=<TransposeBackward0>)"
      ]
     },
     "execution_count": 76,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "logits[0].transpose(1, 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "is contiguous? True\n",
      "tensor([[1, 4, 7],\n",
      "        [2, 5, 8],\n",
      "        [3, 6, 9]])\n",
      "tensor([1, 2, 3, 4, 5, 6, 7, 8, 9])\n",
      "is contiguous? False\n",
      "tensor([[1, 4, 7],\n",
      "        [2, 5, 8],\n",
      "        [3, 6, 9]])\n",
      "tensor([1, 2, 3, 4, 5, 6, 7, 8, 9])\n"
     ]
    }
   ],
   "source": [
    "# potential contiguous problems\n",
    "\n",
    "x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])\n",
    "xt = x.transpose(1, 0).contiguous()\n",
    "print(\"is contiguous?\", xt.is_contiguous())\n",
    "print(xt)\n",
    "\n",
    "y = x.view(-1)\n",
    "print(y)\n",
    "\n",
    "\n",
    "x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])\n",
    "xt = x.transpose(1, 0)\n",
    "print(\"is contiguous?\", xt.is_contiguous())\n",
    "print(xt)\n",
    "\n",
    "y = x.view(-1)\n",
    "print(y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "False\n",
      "tensor([[ 1.3367, -0.1229,  1.4617],\n",
      "        [ 1.1288,  0.8137,  1.2674],\n",
      "        [ 1.2345,  3.2082,  1.5349],\n",
      "        [ 1.2303,  0.3620,  1.8094]])\n",
      "True\n",
      "tensor([[ 1.3367, -0.1229,  1.4617],\n",
      "        [ 1.1288,  0.8137,  1.2674],\n",
      "        [ 1.2345,  3.2082,  1.5349],\n",
      "        [ 1.2303,  0.3620,  1.8094]])\n",
      "are they the same?\n",
      " tensor([[True, True, True],\n",
      "        [True, True, True],\n",
      "        [True, True, True],\n",
      "        [True, True, True]])\n"
     ]
    }
   ],
   "source": [
    "# potential contiguous problems\n",
    "\n",
    "torch.manual_seed(42)\n",
    "x1 = torch.randn(3, 4).transpose(0, 1)  # non-contiguous\n",
    "print(x1.is_contiguous())\n",
    "print(x1.add_(1))  # This may not affect all elements as expected\n",
    "\n",
    "torch.manual_seed(42)\n",
    "x2 = torch.randn(3, 4).transpose(0, 1).contiguous()  # contiguous\n",
    "print(x2.is_contiguous())\n",
    "print(x2.add_(1))  # This will affect all elements as expected\n",
    "\n",
    "print(\"are they the same?\\n\", x1 == x2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n",
       "         0.0000e+00, 0.0000e+00, 1.5087e+00, 5.8585e+00, 1.1923e+01, 2.1637e+00,\n",
       "         2.7706e+00, 2.6047e+00, 5.8061e+00, 1.1112e-02, 7.5943e+00, 1.6238e+00,\n",
       "         1.5568e+00],\n",
       "        [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n",
       "         0.0000e+00, 0.0000e+00, 1.5087e+00, 9.2569e+00, 1.4073e+01, 4.3877e+00,\n",
       "         1.3537e+00, 4.0587e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n",
       "         0.0000e+00],\n",
       "        [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n",
       "         0.0000e+00, 0.0000e+00, 1.5087e+00, 3.0817e+00, 8.8085e-01, 1.1994e+01,\n",
       "         3.4136e+00, 6.1874e-01, 5.1059e-01, 1.5803e+00, 0.0000e+00, 0.0000e+00,\n",
       "         0.0000e+00],\n",
       "        [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n",
       "         0.0000e+00, 0.0000e+00, 5.0368e+00, 4.9301e+00, 3.0012e+00, 9.0537e+00,\n",
       "         4.1260e+00, 3.9609e+00, 2.1955e+00, 2.9106e+00, 1.4165e+00, 0.0000e+00,\n",
       "         0.0000e+00]], grad_fn=<MulBackward0>)"
      ]
     },
     "execution_count": 88,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "masked_shift_losses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([43.4218, 34.6391, 23.5885, 36.6314], grad_fn=<SumBackward1>),\n",
       " tensor([3.9474, 5.7732, 2.9486, 4.0702], grad_fn=<DivBackward0>))"
      ]
     },
     "execution_count": 89,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sum_loss, avg_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gpt",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
